/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "coll_reduce_scatter_mesh_graph_pipeline_executor.h"

namespace hccl {
// Builds the executor by handing both arguments straight to the
// CollReduceScatterExecutor base class.
//   dispatcher  - dispatcher handle forwarded to the base constructor.
//   topoMatcher - topology matcher, forwarded by reference to the base constructor.
CollReduceScatterMeshGraphPipelineExecutor::CollReduceScatterMeshGraphPipelineExecutor(
    const HcclDispatcher dispatcher, std::unique_ptr<TopoMatcher> &topoMatcher)
    : CollReduceScatterExecutor(dispatcher, topoMatcher)
{
    // Nothing to initialise beyond what the base class constructor does.
}

// Caches the fields of the operator parameters that this executor reads later:
// the operation tag and the AICPU unfold mode.
//   param - operator parameters; only `tag` and `aicpuUnfoldMode` are read.
void CollReduceScatterMeshGraphPipelineExecutor::ParseParam(const OpParam &param)
{
    // The tag is used by the Calc* methods for logging and comm-plane setup.
    tag_ = param.tag;

    aicpuUnfoldMode_ = param.aicpuUnfoldMode;
}

// Reports the stream count requested by this executor.
//   streamNum - [out] set to topoAttr_.deviceNumPerAggregation.
// Always returns HCCL_SUCCESS.
HcclResult CollReduceScatterMeshGraphPipelineExecutor::CalcStreamNum(u32 &streamNum)
{
    const u32 devicesPerAggregation = topoAttr_.deviceNumPerAggregation;
    streamNum = devicesPerAggregation;

    HCCL_INFO("[CollReduceScatterMeshGraphPipelineExecutor][CalcStreamNum] tag[%s] streamNum[%u]",
        tag_.c_str(),
        streamNum);
    return HCCL_SUCCESS;
}

// Fills in the transport requirements for both communication levels.
//   opTransport - [in/out] per-level transport descriptions; the COMM_LEVEL0 and
//                 COMM_LEVEL1 entries are populated by the helpers called below.
// Returns HCCL_SUCCESS, or the first non-success code returned by a helper.
HcclResult CollReduceScatterMeshGraphPipelineExecutor::CalcCommInfo(std::vector<LevelNSubCommTransport> &opTransport)
{
    // Start from RESERVED; CalcTransportMemType overwrites both values.
    TransportMemType inMemType{TransportMemType::RESERVED};
    TransportMemType outMemType{TransportMemType::RESERVED};
    CHK_RET(CalcTransportMemType(inMemType, outMemType));

    // Level 0 first, then level 1, both using the memory types chosen above.
    CHK_RET(CalcLevel0CommInfo(inMemType, outMemType, opTransport));
    CHK_RET(CalcLevel1CommInfo(inMemType, outMemType, opTransport));

    return HCCL_SUCCESS;
}

// Selects the memory types used when building transports for this executor.
//   inputType  - [out] always set to TransportMemType::PARAM_INPUT.
//   outputType - [out] always set to TransportMemType::PARAM_OUTPUT.
// Always returns HCCL_SUCCESS.
HcclResult CollReduceScatterMeshGraphPipelineExecutor::CalcTransportMemType(
    TransportMemType &inputType, TransportMemType &outputType)
{
    inputType = TransportMemType::PARAM_INPUT;
    outputType = TransportMemType::PARAM_OUTPUT;

    HCCL_INFO("[CollReduceScatterMeshGraphPipelineExecutor][CalcTransportMemType]tag[%s] inputType[%d], outputType[%d]",
        tag_.c_str(), inputType, outputType);

    return HCCL_SUCCESS;
}

// Computes the level-0 communication plane using a mesh communicator type
// with the meshSinglePlane flag enabled.
//   inputType   - memory type used for the transport input side.
//   outputType  - memory type used for the transport output side.
//   opTransport - [in/out] the COMM_LEVEL0 entry is filled by CalcCommPlaneInfo.
// Returns HCCL_SUCCESS, or the error code returned by CalcCommPlaneInfo.
HcclResult CollReduceScatterMeshGraphPipelineExecutor::CalcLevel0CommInfo(
    TransportMemType inputType, TransportMemType outputType, std::vector<LevelNSubCommTransport> &opTransport)
{
    CommParaInfo level0Para(COMM_LEVEL0, CommType::COMM_TAG_MESH);
    level0Para.meshSinglePlane = true;

    auto &level0Transport = opTransport[COMM_LEVEL0];
    CHK_RET(CalcCommPlaneInfo(tag_, level0Para, level0Transport, inputType, outputType));

    return HCCL_SUCCESS;
}

// Computes the level-1 communication plane. In pipeline mode the Ring
// algorithm is used, so the plane is requested with COMM_TAG_RING_INNER.
//   inputType   - memory type used for the transport input side.
//   outputType  - memory type used for the transport output side.
//   opTransport - [in/out] the COMM_LEVEL1 entry is filled by CalcCommPlaneInfo.
// Returns HCCL_SUCCESS, or the error code returned by CalcCommPlaneInfo.
HcclResult CollReduceScatterMeshGraphPipelineExecutor::CalcLevel1CommInfo(
    TransportMemType inputType, TransportMemType outputType, std::vector<LevelNSubCommTransport> &opTransport)
{
    CommParaInfo level1Para(COMM_LEVEL1, CommType::COMM_TAG_RING_INNER);

    auto &level1Transport = opTransport[COMM_LEVEL1];
    CHK_RET(CalcCommPlaneInfo(tag_, level1Para, level1Transport, inputType, outputType));

    return HCCL_SUCCESS;
}

// Executes the reduce-scatter using the TEMPLATE_REDUCESCATTER_GRAPH_PIPELINE
// algorithm template.
//   param   - operator parameters (data description, input pointer, stream,
//             root, reduce type).
//   execMem - execution memory; inputPtr, outputPtr and inputMem are read.
// Returns HCCL_SUCCESS, or the first non-success code produced by any of the
// checked calls (including a null algorithm template).
HcclResult CollReduceScatterMeshGraphPipelineExecutor::KernelRun(const OpParam &param, ExecMem &execMem)
{
    HCCL_CONFIG_INFO(HCCL_ALG,
        "[CollReduceScatterMeshGraphPipelineExecutor][KernelRun] userRank[%u] starts.", topoAttr_.userRank);

    // Level-0 sub-communicator: validate that index COMM_INDEX_0 exists, then fetch it.
    CHK_RET(CheckCommSize(COMM_LEVEL0, COMM_INDEX_0 + 1));
    SubCommInfo level0CommInfo = GetSubCommInfo(COMM_LEVEL0, COMM_INDEX_0);

    // The level-1 sub-communicator is selected by this rank's local rank in level 0.
    const u32 level1Index = level0CommInfo.localRank;
    CHK_RET(CheckCommSize(COMM_LEVEL1, level1Index + 1));
    SubCommInfo level1CommInfo = GetSubCommInfo(COMM_LEVEL1, level1Index);

    // Derive the reduce attribute from a view over the user input buffer.
    // NOTE(review): the same DeviceMem (sized count * unitSize) is passed as both
    // memory arguments of GetReduceAttr — confirm this is intended.
    const auto &dataDes = param.DataDes;
    u32 unitSize = SIZE_TABLE[dataDes.dataType];
    DeviceMem userInMem = DeviceMem::create(param.inputPtr, dataDes.count * unitSize);
    u64 reduceAttr = GetReduceAttr(userInMem, userInMem, dataDes.dataType, param.reduceType);

    CHK_RET(ActiveSlaveStreams(param.stream));

    // Look up the algorithm template; a null result is reported as an error.
    std::unique_ptr<AlgTemplateBase> tempAlg = AlgTemplateRegistry::Instance().GetAlgTemplate(
        TemplateType::TEMPLATE_REDUCESCATTER_GRAPH_PIPELINE, dispatcher_);
    CHK_SMART_PTR_NULL(tempAlg);

    // Operation description handed to the template; the first field is left empty.
    HcomCollOpInfo opInfo = {
        "", execMem.inputPtr, execMem.outputPtr, dataDes.count, dataDes.dataType, param.root, param.reduceType};

    // Prepare takes a mutable Stream reference, hence the const_cast on param.stream.
    CHK_RET(tempAlg->Prepare(&opInfo, execMem.inputMem, dataDes.count, 0, 0,
        level0CommInfo, level1CommInfo,
        const_cast<Stream &>(param.stream),
        algResResp_->slaveStreams, algResResp_->notifiesMain, algResResp_->notifiesAux,
        reduceAttr));
    CHK_RET(tempAlg->RunAsync());

    return HCCL_SUCCESS;
}

// Registration of this executor class via the REGISTER_EXEC macro, using the
// string "ReduceScatterMeshGraphPipelineExecutor" and the identifier
// ReduceScatterMeshGraphPipeline. The macro is defined outside this file;
// presumably it makes the executor discoverable by name — TODO confirm.
REGISTER_EXEC("ReduceScatterMeshGraphPipelineExecutor", ReduceScatterMeshGraphPipeline,
    CollReduceScatterMeshGraphPipelineExecutor);
}  // namespace hccl